import sys
import json
import numpy as np
from PIL import Image
from glob import glob 
import os
import pandas as pd
import albumentations as alb
import cv2

def load_json(path):
	"""Read the JSON file at ``path`` and return the decoded object."""
	with open(path, mode="r") as handle:
		return json.load(handle)


def IoUfrom2bboxes(boxA, boxB):
    """Return the intersection-over-union of two corner-format boxes.

    Each box is indexed as ``(x_min, y_min, x_max, y_max)``; only the first
    four entries are read. Extents use the inclusive-pixel convention
    (``max - min + 1``), so a box whose min equals its max has area 1.
    A zero union raises ZeroDivisionError.
    """
    # Overlap rectangle (may be empty, clamped to zero extent below).
    left = max(boxA[0], boxB[0])
    top = max(boxA[1], boxB[1])
    right = min(boxA[2], boxB[2])
    bottom = min(boxA[3], boxB[3])
    overlap = max(0, right - left + 1) * max(0, bottom - top + 1)

    area_a = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
    area_b = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
    union = float(area_a + area_b - overlap)
    return overlap / union



def _shift_points(points, dx, dy):
	"""Return a new array holding the (x, y) rows of ``points`` translated by (-dx, -dy).

	The output is allocated with ``np.zeros_like(points)`` so it keeps the
	input's dtype (integer inputs stay integer, float32 stays float32).
	"""
	shifted = np.zeros_like(points)
	shifted[...] = np.asarray(points) - np.array([dx, dy])
	return shifted


def crop_face(img,landmark=None,bbox=None,margin=False,crop_by_bbox=True,abs_coord=False,only_img=False,phase='train'):
	"""Crop a face region from ``img`` around a base box enlarged by per-side margins.

	The base box comes from ``bbox`` when ``crop_by_bbox`` is True, otherwise
	from the extent of the first 68 rows of ``landmark``. Each side of the box
	is enlarged by a margin, and the result is clipped to the image bounds.

	Args:
		img: image indexed as ``img[y, x]`` (``len(img)`` rows of
			``len(img[0])`` columns).
		landmark: sequence of (x, y) points, or None. Required when
			``crop_by_bbox`` is False.
		bbox: ``[[x0, y0], [x1, y1]]``, or None. Required when ``crop_by_bbox``
			is True.
		margin: if True, use fixed large margins (base x4 horizontally,
			x2 vertically) regardless of ``phase``.
		crop_by_bbox: select ``bbox`` (True) or ``landmark`` (False) as the
			base box.
		abs_coord: also return the absolute crop window.
		only_img: return only the cropped image.
		phase: one of 'train', 'val', 'test', 'preprocess'. 'train' scales each
			margin by an independent ``np.random.rand()*0.6+0.2``;
			'preprocess' scales by 0.2; 'val'/'test' scale by 0.5.

	Returns:
		``img_cropped`` if ``only_img``. Otherwise the tuple
		``(img_cropped, landmark_cropped, bbox_cropped,
		(y0 - y0_new, x0 - x0_new, y1_new - y1, x1_new - x1))``, with
		``y0_new, y1_new, x0_new, x1_new`` appended when ``abs_coord``.
		``landmark_cropped`` / ``bbox_cropped`` are the inputs shifted into
		crop coordinates, or None when the corresponding input is None.

	Raises:
		AssertionError: on an unknown ``phase`` or when the input needed for
			the selected crop mode is missing.
	"""
	assert phase in ['train','val','test', 'preprocess']
	assert landmark is not None or bbox is not None
	# The base-box source actually selected must be present; otherwise the
	# failure below would be an opaque TypeError from indexing None.
	if crop_by_bbox:
		assert bbox is not None, 'crop_by_bbox=True requires bbox'
	else:
		assert landmark is not None, 'crop_by_bbox=False requires landmark'

	H,W=len(img),len(img[0])

	# Base box and base margins: w0/w1 = left/right, h0/h1 = y0-side/y1-side.
	if crop_by_bbox:
		x0,y0=bbox[0]
		x1,y1=bbox[1]
		w=x1-x0
		h=y1-y0
		w0_margin=w/4
		w1_margin=w/4
		h0_margin=h/4
		h1_margin=h/4
	else:
		# Only the first 68 landmark points define the box.
		x0,y0=landmark[:68,0].min(),landmark[:68,1].min()
		x1,y1=landmark[:68,0].max(),landmark[:68,1].max()
		w=x1-x0
		h=y1-y0
		w0_margin=w/8
		w1_margin=w/8
		# Asymmetric: a larger margin on the y0 (top-row) side than on y1.
		h0_margin=h/2
		h1_margin=h/5

	if margin:
		w0_margin*=4
		w1_margin*=4
		h0_margin*=2
		h1_margin*=2
	elif phase == 'preprocess':
		w0_margin*=0.2
		w1_margin*=0.2
		h0_margin*=0.2
		h1_margin*=0.2
	elif phase=='train':
		# Independent random scale in [0.2, 0.8) per side. The draw order
		# (w0, w1, h0, h1) is kept so seeded runs stay reproducible.
		w0_margin*=(np.random.rand()*0.6+0.2)
		w1_margin*=(np.random.rand()*0.6+0.2)
		h0_margin*=(np.random.rand()*0.6+0.2)
		h1_margin*=(np.random.rand()*0.6+0.2)
	else:
		w0_margin*=0.5
		w1_margin*=0.5
		h0_margin*=0.5
		h1_margin*=0.5

	# Enlarged window clipped to the image; the +1 makes the end bounds
	# exclusive for slicing.
	y0_new=max(0,int(y0-h0_margin))
	y1_new=min(H,int(y1+h1_margin)+1)
	x0_new=max(0,int(x0-w0_margin))
	x1_new=min(W,int(x1+w1_margin)+1)

	img_cropped=img[y0_new:y1_new,x0_new:x1_new]
	if only_img:
		return img_cropped

	landmark_cropped=None if landmark is None else _shift_points(landmark,x0_new,y0_new)
	bbox_cropped=None if bbox is None else _shift_points(bbox,x0_new,y0_new)

	pads=(y0-y0_new,x0-x0_new,y1_new-y1,x1_new-x1)
	if abs_coord:
		return img_cropped,landmark_cropped,bbox_cropped,pads,y0_new,y1_new,x0_new,x1_new
	return img_cropped,landmark_cropped,bbox_cropped,pads


class RandomDownScale(alb.core.transforms_interface.ImageOnlyTransform):
	"""Image-only augmentation that lowers effective resolution by a random factor.

	The image is shrunk by a factor drawn from ``(2, 4)`` using
	nearest-neighbour sampling, then resized back to its original
	height/width with bilinear interpolation.
	"""

	def apply(self,img,**params):
		return self.randomdownscale(img)

	def randomdownscale(self,img):
		"""Return ``img`` downscaled by 2 or 4 (chosen via ``np.random``) and resized back.

		Only the first two dimensions of ``img.shape`` are used, so both
		(H, W) and (H, W, C) arrays are accepted.
		"""
		H,W=img.shape[:2]
		ratio_list=[2,4]
		r=ratio_list[np.random.randint(len(ratio_list))]
		# Clamp to 1 px so very small inputs never request an empty output size.
		small_size=(max(1,int(W/r)),max(1,int(H/r)))
		img_ds=cv2.resize(img,small_size,interpolation=cv2.INTER_NEAREST)
		return cv2.resize(img_ds,(W,H),interpolation=cv2.INTER_LINEAR)
	
def reorder_landmark(landmark):
	"""Permute rows 68 onward of ``landmark`` in place and return it.

	Rows 68.. are overwritten with the rows at the source indices below,
	taken from the original (pre-mutation) array via a float64 staging copy.
	This assumes exactly 81 rows (68 + 13); other lengths fail on assignment.
	"""
	source_rows = (77, 75, 76, 68, 69, 70, 71, 80, 72, 73, 79, 74, 78)
	# Stage the picked rows first so overlapping reads see the original values.
	staged = np.array([landmark[j] for j in source_rows], dtype=np.float64)
	landmark[68:] = staged
	return landmark

def load_and_crop_face(filename,return_ori_bbox=False,replace_key=None):
	"""Load a frame with its landmark and face-box sidecars and compute crop windows.

	Sidecar paths are derived from ``filename`` by replacing ``.png`` with
	``.npy`` and ``/frames/`` with ``/landmarks/`` or ``/retina/``. The
	landmark set used is entry 0 of the landmark file (presumably the first
	face — confirm against the preprocessing script). Among the first two
	retina boxes, the one with the highest IoU against the landmark extent
	is chosen; ties keep the earlier box.

	Args:
		filename: path to the frame image; must contain ``/frames/`` for the
			sidecar rewrite to point elsewhere.
		return_ori_bbox: return early with the untouched landmarks and box.
		replace_key: if given, the image itself is read from ``filename`` with
			``'frames'`` replaced by this string; sidecars still use ``filename``.

	Returns:
		If ``return_ori_bbox``:
		``(img, ori_landmark, ori_bbox, (y0_min, y1_min, x0_min, x1_min))``.
		Otherwise:
		``(img, landmark, ori_landmark, bbox, (y0_min, y1_min, x0_min, x1_min),
		(y0_tmp, y1_tmp, x0_tmp, x1_tmp))``, where ``landmark`` (reordered) and
		``bbox`` are shifted into the coordinates of the wide landmark-based
		crop. ``*_min`` is the tight bbox-based window ('preprocess' margins),
		and ``*_tmp`` is the wide landmark-based window (``margin=True``).

	Raises:
		ValueError: if no retina box could be selected (empty file, or no
			comparable IoU).
	"""
	image_path = filename if replace_key is None else filename.replace('frames',replace_key)
	# Context manager closes the underlying file deterministically.
	with Image.open(image_path) as pil_img:
		img = np.array(pil_img)

	npy_path = filename.replace(".png",".npy")
	ori_landmark = np.load(npy_path.replace("/frames/","/landmarks/"))
	landmark = ori_landmark.copy()[0]
	# Landmark extent as (x_min, y_min, x_max, y_max) for IoU matching.
	bbox_lm = np.array([
		landmark[:, 0].min(),
		landmark[:, 1].min(),
		landmark[:, 0].max(),
		landmark[:, 1].max(),
	])

	ori_bboxes = np.load(npy_path.replace("/frames/","/retina/"))
	bboxes = ori_bboxes.copy()[:2]
	bbox = None
	iou_max = -1
	for candidate in bboxes:
		iou = IoUfrom2bboxes(bbox_lm, candidate.flatten())
		if iou_max < iou:
			bbox = candidate
			iou_max = iou
	if bbox is None:
		# Previously surfaced as an UnboundLocalError on `bbox`.
		raise ValueError(f"no usable face box found in retina file for {filename}")
	ori_bbox = bbox.copy()
	landmark = reorder_landmark(landmark)

	# Tight window around the detector box (deterministic 'preprocess' margins).
	_, __, ___, ____, y0_min, y1_min, x0_min, x1_min = crop_face(
		img,
		landmark,
		bbox,
		margin=False,
		crop_by_bbox=True,
		abs_coord=True,
		phase='preprocess'
	)

	if return_ori_bbox:
		# The wide crop below only matters for the full return; with
		# margin=True it draws no random numbers, so skipping it is safe.
		return img,ori_landmark,ori_bbox,(y0_min, y1_min, x0_min, x1_min)

	# Wide landmark-based window; landmark/bbox come back in its coordinates.
	_, landmark, bbox, __, y0_tmp, y1_tmp, x0_tmp, x1_tmp = crop_face(
		img,
		landmark,
		bbox,
		margin=True,
		crop_by_bbox=False,
		abs_coord=True
	)

	return img, landmark, ori_landmark, bbox, (y0_min, y1_min, x0_min, x1_min), (y0_tmp, y1_tmp, x0_tmp, x1_tmp)

def add_backdoor(img, bbox, bd_mode,phase, bd_image_pre_transform, landmark = None ,**kwargs ):
	"""Build a (real, fake) image pair with a backdoor trigger applied per ``bd_mode``.

	Every output is produced from a fresh ``img.copy()``. The transform is
	always called with ``bbox`` as a 4-tuple keyword.

	Args:
		img: source image; must support ``.copy()``.
		bbox: 4-sequence ``(y0_min, y1_min, x0_min, x1_min)``.
		bd_mode: dispatched on prefix:
			'distinct...': both images triggered. In 'train' the fake one uses
				``pos='tl'``; in other phases the real one does.
			'real_only...': 'train' triggers only real; other phases only fake.
			'fake_only...': 'train' triggers only fake; other phases only real.
			'same...': one triggered image (with ``target`` and ``landmark``)
				is used for both outputs (fake is a copy of real).
		phase: ``'train'`` versus any other value.
		bd_image_pre_transform: callable that applies the trigger.
		landmark: forwarded to the transform only in 'same' mode.
		**kwargs: 'same' mode reads ``difft_trigger``, falling back to
			``filename``, as the transform's ``target``.

	Returns:
		``(r_img, f_img)``.

	Raises:
		AssertionError: if ``bd_image_pre_transform`` is None.
		ValueError: if ``bd_mode`` has none of the supported prefixes.
	"""
	y0_min, y1_min, x0_min, x1_min = bbox
	assert bd_image_pre_transform is not None
	box = (y0_min, y1_min, x0_min, x1_min)
	is_train = phase == 'train'

	if bd_mode.startswith('distinct'):  # joint with all2all attack
		if is_train:
			r_img = bd_image_pre_transform(img.copy(),bbox=box)
			f_img = bd_image_pre_transform(img.copy(),pos='tl',bbox=box)
		else:
			r_img = bd_image_pre_transform(img.copy(),pos='tl',bbox=box)
			f_img = bd_image_pre_transform(img.copy(),bbox=box)
	elif bd_mode.startswith('real_only'):
		if is_train:
			r_img = bd_image_pre_transform(img.copy(),bbox=box)
			f_img = img.copy()
		else:
			r_img = img.copy()
			f_img = bd_image_pre_transform(img.copy(),bbox=box)
	elif bd_mode.startswith('fake_only'):
		if is_train:
			r_img = img.copy()
			f_img = bd_image_pre_transform(img.copy(),bbox=box)
		else:
			r_img = bd_image_pre_transform(img.copy(),bbox=box)
			f_img = img.copy()
	elif bd_mode.startswith('same'):
		target = kwargs.get('difft_trigger',None)
		if target is None:
			target = kwargs.get('filename',None)
		r_img = bd_image_pre_transform(img.copy(),target = target,bbox=box,landmark=landmark)
		f_img = r_img.copy()
	else:
		# Previously fell through and failed with UnboundLocalError at return.
		raise ValueError(f"unsupported bd_mode: {bd_mode!r}")
	return r_img, f_img